/*
 * Copyright (c) 2020 - present, Inspur Genersoft Co., Ltd.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.iec.edp.caf.commons.cryptography.providers;

import io.iec.edp.caf.commons.cryptography.model.CafKeyPair;
import io.iec.edp.caf.commons.cryptography.service.AsymmetricCryptoProvider;

import javax.crypto.Cipher;
import java.security.*;
import java.security.interfaces.RSAKey;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.X509EncodedKeySpec;

/**
 * RSA encryption/decryption and signing/verification provider.
 *
 * <p>Encryption uses raw RSA ({@value #MODE}) with a custom block framing:
 * <ul>
 *   <li>The input is split into chunks of {@code k - PADDING_SIZE} bytes, where {@code k} is the
 *       key's modulus length in bytes.</li>
 *   <li>Each chunk is stored after {@link #PADDING_SIZE} leading zero byte(s) in a {@code k}-byte
 *       block, and every block is transformed independently.</li>
 *   <li>The leading zero byte keeps each block numerically smaller than the modulus.</li>
 *   <li>Decryption reverses this framing.</li>
 * </ul>
 *
 * <p>NOTE(review): the following weaknesses are kept as-is so that existing keys, ciphertexts and
 * signatures stay interoperable:
 * <ul>
 *   <li>The framing carries no length header. The last plaintext chunk is zero-filled on
 *       encryption, and decryption returns those trailing zero bytes.</li>
 *   <li>Raw RSA without OAEP/PKCS#1 padding is deterministic and malleable.</li>
 *   <li>{@value #KEY_SIZE}-bit keys and {@value #SIGNATURE} are weak by current standards.</li>
 * </ul>
 *
 * @author guowenchang
 */
public class RSACryptoProvider implements AsymmetricCryptoProvider {

    public static final int KEY_SIZE = 1024;
    public static final String ALGORITHM = "RSA";
    public static final String SIGNATURE = "SHA1WithRSA";
    public static final String MODE = "RSA/ECB/NoPadding";
    public static final int PADDING_SIZE = 1;

    /**
     * Generates a new RSA key pair with a {@link #KEY_SIZE}-bit modulus.
     *
     * @return the generated key pair wrapped in a {@link CafKeyPair}
     * @throws RuntimeException if key pair generation fails (the original failure is the cause)
     */
    @Override
    public CafKeyPair generateKeyPair() {
        try {
            // RSA key generation needs a trustworthy source of randomness.
            SecureRandom secureRandom = new SecureRandom();
            KeyPairGenerator keyPairGen = KeyPairGenerator.getInstance(ALGORITHM);
            keyPairGen.initialize(KEY_SIZE, secureRandom);
            KeyPair keyPair = keyPairGen.generateKeyPair();
            return new CafKeyPair(keyPair);
        } catch (Exception e) {
            throw new RuntimeException("generate key failed.", e);
        }
    }

    /**
     * Encrypts {@code byteIn} with the given encoded RSA key.
     *
     * @param byteIn    plaintext bytes
     * @param key       an X.509-encoded public key when {@code isReverse} is {@code false}, or a
     *                  PKCS#8-encoded private key when {@code isReverse} is {@code true}
     * @param isReverse {@code true} to encrypt with the private key instead of the public key
     * @return the ciphertext, one modulus-length block per plaintext chunk
     * @throws RuntimeException if the key cannot be decoded or the cipher fails
     */
    @Override
    public byte[] encrypt(byte[] byteIn, byte[] key, boolean isReverse) {
        Key cipherKey = isReverse ? transformPrivateKey(key) : transformPublicKey(key);
        return doFinal(byteIn, Cipher.ENCRYPT_MODE, cipherKey);
    }

    /**
     * Decrypts {@code byteIn} with the given encoded RSA key.
     *
     * @param byteIn    ciphertext produced by {@link #encrypt}; must be a whole number of blocks
     * @param key       a PKCS#8-encoded private key when {@code isReverse} is {@code false}, or an
     *                  X.509-encoded public key when {@code isReverse} is {@code true}
     * @param isReverse {@code true} to decrypt with the public key instead of the private key
     * @return the plaintext, including any trailing zero fill from the last encrypted chunk
     * @throws IllegalArgumentException if {@code byteIn} is not a whole number of RSA blocks
     * @throws RuntimeException         if the key cannot be decoded or the cipher fails
     */
    @Override
    public byte[] decrypt(byte[] byteIn, byte[] key, boolean isReverse) {
        Key cipherKey = isReverse ? transformPublicKey(key) : transformPrivateKey(key);
        return doFinal(byteIn, Cipher.DECRYPT_MODE, cipherKey);
    }

    /**
     * Decodes an X.509 (SubjectPublicKeyInfo) encoded RSA public key.
     *
     * @throws RuntimeException if decoding fails (the original failure is the cause)
     */
    private RSAPublicKey transformPublicKey(byte[] key) {
        try {
            KeyFactory keyFactory = KeyFactory.getInstance(ALGORITHM);
            return (RSAPublicKey) keyFactory.generatePublic(new X509EncodedKeySpec(key));
        } catch (Exception e) {
            throw new RuntimeException("transform key failed.", e);
        }
    }

    /**
     * Decodes a PKCS#8 encoded RSA private key.
     *
     * @throws RuntimeException if decoding fails (the original failure is the cause)
     */
    private RSAPrivateKey transformPrivateKey(byte[] key) {
        try {
            KeyFactory keyFactory = KeyFactory.getInstance(ALGORITHM);
            return (RSAPrivateKey) keyFactory.generatePrivate(new PKCS8EncodedKeySpec(key));
        } catch (Exception e) {
            throw new RuntimeException("transform key failed.", e);
        }
    }

    /**
     * Transforms {@code data} block by block with raw RSA. Here {@code k} is the modulus length of
     * {@code key} in bytes.
     *
     * <p>Encryption:
     * <ul>
     *   <li>Splits the input into chunks of {@code k - PADDING_SIZE} bytes; the last chunk is
     *       zero-filled.</li>
     *   <li>Prefixes each chunk with {@link #PADDING_SIZE} zero byte(s).</li>
     *   <li>Emits one {@code k}-byte ciphertext block per chunk.</li>
     * </ul>
     *
     * <p>Decryption:
     * <ul>
     *   <li>Expects a sequence of {@code k}-byte blocks.</li>
     *   <li>Emits {@code k - PADDING_SIZE} bytes per block, dropping the leading padding.</li>
     * </ul>
     *
     * @param data input data; empty input is treated as a single all-zero chunk/block
     * @param mode {@link Cipher#ENCRYPT_MODE}; any other value is treated as decryption
     * @param key  the RSA key to transform with
     * @return the transformed bytes
     * @throws IllegalArgumentException if decrypting and {@code data} is not a whole number of blocks
     * @throws RuntimeException         if the underlying cipher fails
     */
    private byte[] doFinal(byte[] data, int mode, Key key) {
        int fullBlockSize = modulusLength(key);
        int blockSize = fullBlockSize - PADDING_SIZE;
        // Cipher.doFinal resets the cipher to its initialized state, so one instance serves all blocks.
        Cipher cipher = newCipher(mode, key);

        if (mode == Cipher.ENCRYPT_MODE) {
            // Ceiling division; an empty input still yields one block.
            int blockNum = (data.length - 1) / blockSize + 1;
            byte[] result = new byte[blockNum * fullBlockSize];
            for (int i = 0; i < blockNum; i++) {
                int offset = i * blockSize;
                byte[] blockNow = new byte[fullBlockSize];
                System.arraycopy(data, offset, blockNow, PADDING_SIZE,
                        Math.min(blockSize, data.length - offset));
                byte[] resultNow = leftPad(doFinalByGroup(cipher, blockNow), fullBlockSize);
                System.arraycopy(resultNow, 0, result, i * fullBlockSize, fullBlockSize);
            }
            return result;
        }

        if (data.length % fullBlockSize != 0) {
            throw new IllegalArgumentException("ciphertext length " + data.length
                    + " is not a multiple of the RSA block size " + fullBlockSize);
        }
        // Ceiling division; an empty input still yields one (all-zero) block.
        int blockNum = (data.length - 1) / fullBlockSize + 1;
        byte[] result = new byte[blockNum * blockSize];
        for (int i = 0; i < blockNum; i++) {
            int offset = i * fullBlockSize;
            byte[] blockNow = new byte[fullBlockSize];
            System.arraycopy(data, offset, blockNow, 0, Math.min(fullBlockSize, data.length - offset));
            // Right-align so the padding byte is dropped even if a provider strips leading zeros.
            byte[] resultNow = leftPad(doFinalByGroup(cipher, blockNow), fullBlockSize);
            System.arraycopy(resultNow, PADDING_SIZE, result, i * blockSize, blockSize);
        }
        return result;
    }

    /**
     * Returns the length in bytes of one raw RSA block for {@code key}.
     *
     * <p>For an RSA key this is the byte length of its modulus. For any other key it falls back to
     * the size of keys produced by {@link #generateKeyPair()}.
     */
    private static int modulusLength(Key key) {
        if (key instanceof RSAKey) {
            return (((RSAKey) key).getModulus().bitLength() + 7) / 8;
        }
        return KEY_SIZE / 8;
    }

    /** Creates a {@link #MODE} cipher initialized for {@code mode} with {@code key}. */
    private static Cipher newCipher(int mode, Key key) {
        try {
            Cipher cipher = Cipher.getInstance(MODE);
            cipher.init(mode, key);
            return cipher;
        } catch (Exception e) {
            throw new RuntimeException("transform failed.", e);
        }
    }

    /**
     * Transforms a single block with an already initialized cipher.
     *
     * @param cipher the cipher returned by {@link #newCipher}
     * @param data   the current block
     * @return the cipher output for this block
     */
    private static byte[] doFinalByGroup(Cipher cipher, byte[] data) {
        try {
            return cipher.doFinal(data);
        } catch (Exception e) {
            throw new RuntimeException("transform failed.", e);
        }
    }

    /**
     * Right-aligns {@code bytes} in a zero-filled array of {@code length} bytes.
     *
     * @return {@code bytes} itself when it already has the requested length
     * @throws IllegalStateException if {@code bytes} is longer than {@code length}
     */
    private static byte[] leftPad(byte[] bytes, int length) {
        if (bytes.length == length) {
            return bytes;
        }
        if (bytes.length > length) {
            throw new IllegalStateException(
                    "RSA output of " + bytes.length + " bytes exceeds block size " + length);
        }
        byte[] padded = new byte[length];
        System.arraycopy(bytes, 0, padded, length - bytes.length, bytes.length);
        return padded;
    }

    /**
     * Signs {@code byteIn} with {@value #SIGNATURE}.
     *
     * @param byteIn data to sign
     * @param key    PKCS#8-encoded RSA private key
     * @return the signature bytes
     * @throws RuntimeException if the key cannot be decoded or signing fails
     */
    @Override
    public byte[] sign(byte[] byteIn, byte[] key) {
        try {
            PrivateKey privateKey = transformPrivateKey(key);
            Signature signature = Signature.getInstance(SIGNATURE);
            signature.initSign(privateKey);
            signature.update(byteIn);
            return signature.sign();
        } catch (Exception e) {
            throw new RuntimeException("sign failed.", e);
        }
    }

    /**
     * Verifies a {@value #SIGNATURE} signature.
     *
     * @param byteIn the signed data
     * @param sign   the signature to check
     * @param key    X.509-encoded RSA public key
     * @return {@code true} if the signature is valid; {@code false} if it is invalid or any error
     *         occurs
     */
    @Override
    public boolean verify(byte[] byteIn, byte[] sign, byte[] key) {
        try {
            PublicKey publicKey = transformPublicKey(key);
            Signature signature = Signature.getInstance(SIGNATURE);
            signature.initVerify(publicKey);
            signature.update(byteIn);
            return signature.verify(sign);
        } catch (Exception e) {
            // Verification is best-effort: any failure counts as "not verified". The trace is
            // printed because the exception is not propagated to the caller.
            e.printStackTrace();
            return false;
        }
    }
}
